// Problem statement: https://www.lintcode.com/problem/binary-tree-pruning/

/**
 * Definition of TreeNode:
 * public class TreeNode {
 *     public int val;
 *     public TreeNode left, right;
 *     public TreeNode(int val) {
 *         this.val = val;
 *         this.left = this.right = null;
 *     }
 * }
 */

public class Solution {
    /**
     * Prunes the tree in place, detaching every subtree that holds no node whose value is 1.
     *
     * Children are processed before their parent (post-order). A node is dropped only when
     * both of its already-pruned children are gone and its own value is 0.
     *
     * @param root the root of the tree to prune; may be null
     * @return the (possibly same) root after pruning, or null if nothing with a 1 remains
     */
    public TreeNode pruneTree(TreeNode root) {
        // Nothing to prune in an empty subtree.
        if (root == null) {
            return root;
        }

        // Prune the left side first, then the right side, re-attaching each result.
        TreeNode keptLeft = pruneTree(root.left);
        root.left = keptLeft;
        TreeNode keptRight = pruneTree(root.right);
        root.right = keptRight;

        // After pruning, a childless node survives only if it carries a non-zero value itself.
        boolean becameLeaf = keptLeft == null && keptRight == null;
        return (becameLeaf && root.val == 0) ? null : root;
    }
}